import copy
import pickle
import sys
import time
import random
import os
import numpy as np
import pandas as pd

import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras
#tf.config.set_per_process_memory_growth(True)

# tf.debugging.set_log_device_placement(True)

from keras.layers import Input, Dense, Conv2D, LeakyReLU, Dropout, Flatten, MaxPooling2D, GlobalAveragePooling2D
from tqdm import tqdm

from csmodels import xrayGenerator, xrayDiscriminator
from tfFunctionsUtils import apply_gumbel_softmax, map_fill_to_discrete, compare_conditionals_within
from tfFunctionsUtils import get_joint_distributions_from_samples, penalty_calculation
from sklearn.preprocessing import OneHotEncoder
from tfFunctionsUtils import calculate_TVD
from csxray_graph import Experiment
from tfFunctionsUtils import getdoKey
from tfFunctionsUtils import load_dataset
from csxray_graph import set_Xray
from tfFunctionsUtils import calculate_KL

from keras.optimizers import Adam

from trainXrayEncoder.ConvolutionalCondVAE import Encoder, ConvCVAE
from trainXrayEncoder.ConvolutionalCondVAE import Decoder
from trainXrayEncoder.multiple_disc_wdummy_wEncoder import covidGen, PneumGen, \
    Discriminator


from keras import Model
from tfFunctionsUtils import load_dataset
from csmodels import xrayGenerator

def get_generators(Exp, load_which_models):
    """Build the label/representation mechanisms for the nodes of ``Exp.Observed_DAG``.

    Recognised nodes (any other node name is skipped):

    * ``'covid_19'`` -> a trainable ``covidGen`` plus an Adam optimizer.
    * ``'pneum'``    -> ``[frozen xrayDiscriminator restored from the epoch-288
      weights, PneumGen]`` plus an Adam optimizer (intended for the PneumGen part).
    * ``'Rxray'``    -> a frozen ``ConvCVAE`` restored from the latest checkpoint in
      its checkpoint directory, if one exists (no optimizer).

    ``load_which_models`` is accepted for interface compatibility but not read here.

    Returns:
        (label_generators, optimizersMech): dicts keyed by node name.
    """
    generators = {}
    optimizers = {}

    def new_adam():
        return tf.keras.optimizers.Adam(Exp.learning_rate, beta_1=0.5, beta_2=0.9)

    for node in Exp.Observed_DAG:
        noise_dim = Exp.CONF_NOISE_DIM

        if node == 'covid_19':
            generators[node] = covidGen(noise_dim, 2)
            optimizers[node] = new_adam()

        elif node == 'pneum':
            # Pre-trained image discriminator, frozen; used as the image classifier
            # whose output conditions PneumGen.
            disc_epoch = 288
            image_disc = xrayDiscriminator(input_shape=(Exp.IMAGE_SIZE, Exp.IMAGE_SIZE, 3))
            image_disc.load_weights(f"/SaveDir/params_discriminator_epoch_{disc_epoch}.hdf5")
            image_disc.trainable = False

            generators[node] = [image_disc, PneumGen(3, noise_dim, 2)]
            optimizers[node] = new_adam()

        elif node == 'Rxray':
            cvae_latent_dim = 128
            cvae = ConvCVAE(
                Encoder(cvae_latent_dim),
                Decoder(),
                label_dim=1,
                latent_dim=cvae_latent_dim,
                beta=0.65,
                image_dim=[64, 64, 3])

            ckpt_dir = "/trainXrayEncoder/CVAE128_0.65_checkpoint"
            ckpt = tf.train.Checkpoint(module=cvae)
            newest = tf.train.latest_checkpoint(ckpt_dir)
            if newest is None:
                print("No checkpoint!")
            else:
                ckpt.restore(newest)
                print("Checkpoint restored:", newest)

            # Frozen encoder; per the original comments it maps 64x64x3 -> 129.
            cvae.trainable = False
            generators[node] = cvae

    return generators, optimizers


def get_discriminators(Exp):
    """Create the WGAN-GP critics and their Adam optimizers.

    * ``'H2'`` critic: input width ``Exp.label_dim['covid_19'] +
      Exp.label_dim['pneum'] + 129`` (129 = 128-d CVAE latent + 1 label).
    * ``'covid_19'`` critic: input width 2.

    Returns:
        (discriminatorsMech, doptimizersMech): dicts keyed 'H2' then 'covid_19'.
    """
    encoded_image_width = 129  # latent dim 128 + label 1
    critic_widths = (
        ('H2', Exp.label_dim['covid_19'] + Exp.label_dim['pneum'] + encoded_image_width),
        ('covid_19', 2),
    )

    critics = {}
    critic_optimizers = {}
    for name, width in critic_widths:
        critics[name] = Discriminator(width)
        critic_optimizers[name] = tf.keras.optimizers.Adam(Exp.learning_rate, beta_1=0.5, beta_2=0.9)

    return critics, critic_optimizers





def get_generated_labels(Exp, label_generators, intervened, batch_size, data_batch=None):
    """Draw one batch from the generative model (no recursion; one shared noise draw).

    Sampling order: covid_19 -> image-generator condition -> X-ray image ->
    CVAE code of the image ('Rxray') -> pneum.

    Args:
        Exp: experiment config; reads ``NOISE_DIM``, ``label_dim['covid_19']`` and
            ``Temperature``.
        label_generators: dict of models; uses 'covid_19', 'xray', 'Rxray' and
            'pneum' (a two-element list [image classifier, PneumGen]).
        intervened: dict; if it contains 'covid_19' that node is clamped
            (value 0 -> one-hot [1, 0], anything else -> [0, 1]) instead of sampled.
        batch_size: number of samples to draw.
        data_batch: optional one-element list whose item is the pneumonia one-hot
            of real data (column 1 is used). ``None`` or an empty list means
            "no real data": a random image-generator condition is used instead.

    Returns:
        dict with 'covid_19' and 'pneum' (soft one-hots), 'xray' (generated
        images), 'img_noise' and 'xrayInput' (the image generator's inputs) and
        'Rxray' (CVAE code of the generated images).
    """
    # ``None`` replaces the former mutable ``[]`` default; an empty list still
    # means "no real pneumonia data".
    if data_batch is None:
        data_batch = []

    gen_labels = {}
    # NOTE(review): get_generators builds the label generators with
    # Exp.CONF_NOISE_DIM, but this noise uses Exp.NOISE_DIM; both are 64 in the
    # __main__ config — confirm they are meant to be equal.
    confNoises = tf.random.normal([batch_size, Exp.NOISE_DIM], mean=0.0, stddev=1.0,
                                  dtype=tf.dtypes.float32)

    if 'covid_19' in intervened:
        zeroes = tf.zeros([batch_size, Exp.label_dim['covid_19']])
        ones = tf.ones([batch_size, Exp.label_dim['covid_19']])
        if intervened['covid_19'] == 0:
            gen_labels['covid_19'] = tf.concat(axis=1, values=[ones, zeroes])
        else:
            gen_labels['covid_19'] = tf.concat(axis=1, values=[zeroes, ones])
    else:
        output = label_generators['covid_19']([confNoises], training=True)
        soft, _hard = apply_gumbel_softmax(output, Exp.Temperature)
        gen_labels['covid_19'] = soft

    # ---- image-generator condition ----
    # i0 is column 0 of the covid one-hot, i.e. the (1 - C) indicator.
    i0 = tf.reshape(gen_labels['covid_19'][:, 0], [-1, 1])
    if len(data_batch) == 0:
        # No real pneumonia data: (1 - C) * (1.5 + T + N), T in {-0.5, 0.5},
        # N a very narrow Gaussian.
        T = random.choices([-0.5, 0.5], weights=[0.5, 0.5], k=batch_size)
        T = tf.reshape(T, [-1, 1])
        N = tf.random.normal(shape=(batch_size, 1), mean=0.0, stddev=0.0001, dtype=tf.float32)
        img_par = i0 * (1.5 + T + N)
    else:
        # Real pneumonia data: (1 - C) * (pneum + 1); column 1 is pneum == 1.
        pneum_onehot = data_batch[0]
        n1 = tf.cast(tf.reshape(pneum_onehot[:, 1], [-1, 1]), tf.dtypes.float32)
        img_par = i0 * (n1 + tf.ones([batch_size, 1]))

    # ---- X-ray image ----
    image_noise = np.random.uniform(-1, 1, (batch_size, 100))
    gen_labels['xray'] = label_generators['xray']([image_noise, img_par])
    gen_labels['img_noise'] = image_noise
    gen_labels['xrayInput'] = img_par

    # ---- Rxray: CVAE code of the generated image, conditioned on img_par ----
    rxray = label_generators['Rxray']
    resized_images = tf.image.resize(gen_labels['xray'], (64, 64))
    _, input_label, conditional_input = rxray.conditional_input([resized_images, img_par])
    encoded = rxray.encoder(conditional_input, rxray.latent_dim, is_train=False)
    z_mean, z_log_var = tf.split(encoded, num_or_size_splits=2, axis=1)
    gen_labels['Rxray'] = rxray.reparametrization(z_mean, z_log_var, input_label)

    # ---- pneum: image classifier's class output + shared noise -> PneumGen ----
    image_classifier = label_generators['pneum'][0]
    pneum_gen = label_generators['pneum'][1]
    _real_fake, class3_labels = image_classifier([gen_labels['xray']])
    output = pneum_gen([class3_labels, confNoises], training=True)
    soft, _hard = apply_gumbel_softmax(output, Exp.Temperature)
    gen_labels['pneum'] = soft

    return gen_labels

# Label smoothing for the binary real/fake discriminator targets.
def label_smoothing(vector, max_dev=0.2):
    """Return a randomly smoothed copy of a 0/1 target array.

    Entries equal to 0 are raised by ``U[0, max_dev)``; every other entry is
    lowered by ``U[0, max_dev)``. Real targets (1) therefore land in
    ``(1 - max_dev, 1]`` and fake targets (0) in ``[0, max_dev)``.

    Args:
        vector: array of 0/1 targets of any shape (callers pass ``(batch, 1)``).
        max_dev: upper bound of the uniform perturbation.

    Returns:
        A new float array with the shape of ``vector``; the input is not modified.
    """
    vector = np.asarray(vector)
    # One uniform draw per element; for 2-D input this consumes the RNG exactly
    # like np.random.rand(rows, cols).
    d = max_dev * np.random.rand(*vector.shape)
    # Decide the direction per element. Keying it off vector[0][0] alone
    # mis-smoothed mixed 0/1 arrays and raised on empty or 1-D input.
    return np.where(vector == 0, vector + d, vector - d)



def soft_trainXray(Exp, G_fake_first, G_fake_second, image_batch, original_databatch,  dis, combined, batch_size):
    """Run one ACGAN update of the X-ray image discriminator and generator.

    Not called at present; the call site in do_train is commented out.

    Discriminator step: ``dis`` is trained on the real images followed by
    ``G_fake_first['xray']``. Source targets are smoothed ones (real) and zeros
    (fake). Class targets are ``(1 - covid) * (pneum + 1)`` from the raw labels
    for real rows and ``G_fake_first['xrayInput']`` for fake rows.

    Generator step: ``combined`` is trained on the noise/condition pairs of both
    fake batches with all-ones source targets, so the generator learns to fool
    ``dis``.

    ``Exp`` is accepted but not used.

    Returns:
        (gen_loss, dis_loss) as returned by the two ``train_on_batch`` calls.
    """
    # ---- discriminator ----
    covid_col = tf.reshape(original_databatch[:, 0], [-1, 1])
    pneum_col = tf.reshape(original_databatch[:, 1], [-1, 1])
    # covid=1 -> 0 ; (covid=0, pneum=0) -> 1 ; (covid=0, pneum=1) -> 2
    real_classes = (1 - covid_col) * (pneum_col + 1)

    fake_classes = G_fake_first['xrayInput']
    all_images = np.concatenate((image_batch, G_fake_first['xray']))

    # Real targets are drawn before fake targets.
    real_targets = label_smoothing(vector=np.ones((batch_size, 1)), max_dev=0.2)
    fake_targets = label_smoothing(vector=np.zeros((batch_size, 1)), max_dev=0.2)
    source_targets = np.concatenate((real_targets, fake_targets), axis=0)
    class_targets = np.concatenate((real_classes, fake_classes), axis=0)

    dis_loss = dis.train_on_batch(all_images, [source_targets, class_targets])

    # ---- generator, through the combined model (discriminator frozen) ----
    both_noise = tf.concat(axis=0, values=[G_fake_first['img_noise'], G_fake_second['img_noise']])
    both_conditions = tf.concat(axis=0, values=[G_fake_first['xrayInput'], G_fake_second['xrayInput']])
    fool_targets = np.ones(2 * batch_size)
    gen_loss = combined.train_on_batch(
        [both_noise, tf.reshape(both_conditions, [-1, 1])],
        [fool_targets, both_conditions])

    return gen_loss, dis_loss



def calculate_joint(Exp, keep_G_fake):
    """Empirical (covid_19, pneum) distributions from stacked one-hot samples.

    Args:
        Exp: experiment config, forwarded to ``compare_conditionals_within``.
        keep_G_fake: 2-D array/tensor. Columns 0:2 hold the covid_19 one-hot
            (soft or hard) and columns 2:4 the pneum one-hot. Each is reduced to
            a hard label by argmax.

    Returns:
        (joint_prob, cond_prob_list, covid_prob, pneum_prob): P(covid, pneum),
        the P(pneum | covid) list from ``compare_conditionals_within``, P(covid)
        and P(pneum).
    """
    def to_hard_column(one_hot):
        print(one_hot.shape)  # debug trace of the slice being reduced
        return tf.reshape(tf.math.argmax(one_hot, axis=1), [-1, 1])

    covid19 = to_hard_column(keep_G_fake[:, 0:2])
    pneum = to_hard_column(keep_G_fake[:, 2:4])

    joint = tf.concat(axis=1, values=[tf.cast(covid19, tf.int32), tf.cast(pneum, tf.int32)])

    joint_prob = get_joint_distributions_from_samples(['covid_19', 'pneum'], [2, 2], joint.numpy())
    covid_prob = get_joint_distributions_from_samples(['covid_19'], [2], covid19.numpy())
    pneum_prob = get_joint_distributions_from_samples(['pneum'], [2], pneum.numpy())

    # P(pneum | covid_19)
    cond_prob_list = compare_conditionals_within(
        Exp, joint.numpy(), ['pneum'], ['covid_19'], ['covid_19', 'pneum'])

    return joint_prob, cond_prob_list, covid_prob, pneum_prob


def _critic_step(Exp, critic, optimizer, real, fake):
    """Apply one WGAN-GP update to ``critic`` on (real, fake) and return its loss."""
    with tf.GradientTape() as tape:
        D_real = critic([real], training=True)
        D_fake = critic([fake], training=True)
        penalty = penalty_calculation(critic, real, fake)
        D_loss = tf.reduce_mean(D_fake - D_real + Exp.LAMBDA_GP * penalty)
    grads = tape.gradient(D_loss, critic.trainable_variables)
    optimizer.apply_gradients(zip(grads, critic.trainable_variables))
    return D_loss


def do_train(Exp,  label_generators, discriminators,G_optimizers, D_optimizer, data_batch, image_batch):
    """Run one generator update and one critic update on a single mini-batch.

    Args:
        Exp: experiment config (Temperature, LAMBDA_GP, noise dims, ...).
        label_generators: mechanisms from get_generators plus 'xray'.
        discriminators: WGAN-GP critics. 'H2' scores
            [covid one-hot, pneum one-hot, Rxray code]; 'covid_19' scores the
            covid one-hot alone.
        G_optimizers: optimizers keyed 'covid_19' and 'pneum'.
        D_optimizer: optimizers keyed like ``discriminators``.
        data_batch: (batch, 2) labels [covid_19, pneumonia]; values must be in
            {0, 1} (the one-hot encoder below rejects anything else).
        image_batch: real X-ray images aligned with ``data_batch``.

    Returns:
        (G_loss, D_loss, imggen_loss, imgdis_loss, sample_image). D_loss is the
        loss of the 'covid_19' critic, the last one updated. The image-GAN losses
        are the placeholder -1 while that training is disabled. sample_image is
        the first image generated during the generator step.
    """
    # Raw labels, only needed by the (currently disabled) soft_trainXray call.
    original_databatch = data_batch

    # Fix the categories so every batch encodes to exactly four columns
    # [covid=0, covid=1, pneum=0, pneum=1]. Fitting on the batch alone drops a
    # column whenever a value is absent from the batch, which silently breaks
    # the [:, 0:2] / [:, 2:4] slicing below.
    enc = OneHotEncoder(categories=[[0, 1], [0, 1]])
    data_batch = enc.fit_transform(data_batch).toarray()
    pneum_onehot = data_batch[:, 2:4]

    # -----------------------------------------------------------------------------
    print('Training Generator')
    with tf.GradientTape() as gen_tape:
        G_fake = get_generated_labels(Exp, label_generators, {}, data_batch.shape[0], [pneum_onehot])
        # get_generated_labels returns a fresh dict each call, so a reference is
        # enough (no need to deep-copy whole image batches every step).
        G_fake_first = G_fake
        fake_batch = tf.concat(axis=1, values=[G_fake['covid_19'], G_fake['pneum']])
        fakebatch_wimg = tf.concat(axis=1, values=[fake_batch, G_fake['Rxray']])

        # WGAN generator loss for P(C, Rxray, Pn) ...
        l3 = -tf.reduce_mean(discriminators['H2']([fakebatch_wimg], training=True))
        # ... plus the marginal P(C).
        l1 = -tf.reduce_mean(discriminators['covid_19']([fake_batch[:, 0:2]], training=True))
        G_loss = l3 + l1

        print(f'P(C,Rxray, Pneum)+P(C) : G_loss--->  {G_loss}')

    covid_vars = label_generators['covid_19'].trainable_variables
    pneum_vars = label_generators['pneum'][1].trainable_variables
    grad1, grad2 = gen_tape.gradient(G_loss, [covid_vars, pneum_vars])
    G_optimizers['covid_19'].apply_gradients(zip(grad1, covid_vars))
    G_optimizers['pneum'].apply_gradients(zip(grad2, pneum_vars))

    # -----------------------------------------------------------------------------
    print('Training Discriminator')

    # Encode the real images with the frozen CVAE.
    # NOTE(review): real images are conditioned on one-hot column 0 (covid == 0,
    # values {0, 1}), while get_generated_labels conditions fake images on
    # (1 - C) * (pneum + 1) (values {0, 1, 2}); confirm which label the CVAE expects.
    rxray = label_generators['Rxray']
    resized_images = tf.image.resize(image_batch, (64, 64))
    _, input_label, conditional_input = rxray.conditional_input(
        [resized_images, tf.reshape(data_batch[:, 0], [-1, 1])])
    encoded = rxray.encoder(conditional_input, rxray.latent_dim, is_train=False)
    z_mean, z_log_var = tf.split(encoded, num_or_size_splits=2, axis=1)
    encoded_real_image = rxray.reparametrization(z_mean, z_log_var, input_label)
    databatch_wimg = tf.concat([data_batch, encoded_real_image], 1)

    # Fresh fake sample for the critics (outside any generator tape).
    G_fake = get_generated_labels(Exp, label_generators, {}, data_batch.shape[0], [pneum_onehot])
    G_fake_second = G_fake  # only used by the disabled soft_trainXray call
    fake_batch = tf.concat(axis=1, values=[G_fake['covid_19'], G_fake['pneum']])
    fakebatch_wimg = tf.concat(axis=1, values=[fake_batch, G_fake['Rxray']])

    _critic_step(Exp, discriminators['H2'], D_optimizer['H2'], databatch_wimg, fakebatch_wimg)
    D_loss = _critic_step(Exp, discriminators['covid_19'], D_optimizer['covid_19'],
                          data_batch[:, 0:2], fake_batch[:, 0:2])

    # -----------------------------------------------------------------------------
    print('Training XrayGans')

    # imggen_loss, imgdis_loss= soft_trainXray(Exp, G_fake_first, G_fake_second, image_batch, original_databatch,  discriminatorsMech['xray'], label_generators['combined'], data_batch.shape[0])
    imggen_loss, imgdis_loss = -1, -1
    return G_loss, D_loss, imggen_loss, imgdis_loss, G_fake_first['xray'][0].numpy()





def _pneum_prob_under_do(Exp, label_generators, covid_value, test_size, pneum_onehot):
    """Return the sampled distribution P(pneum | do(covid_19=covid_value))."""
    G_fake = get_generated_labels(Exp, label_generators, {'covid_19': covid_value}, test_size, [pneum_onehot])
    intv_pneum = tf.reshape(tf.math.argmax(G_fake['pneum'], axis=1), [-1, 1])
    return get_joint_distributions_from_samples(['pneum'], [2], intv_pneum.numpy())


def _log_tvd(Exp, key, fake_prob, real_prob):
    """Append round(TVD(fake_prob, real_prob), 4) to ``Exp.tvd_diff[key]``."""
    obs_tvd = calculate_TVD(fake_prob, real_prob, doPrint=False)
    Exp.tvd_diff[key].append(round(obs_tvd, 4))


def trainloop(Exp, cur_hnodes, label_generators, G_optimizers, discriminators, D_optimizers, train_dataset):
    """Train for one epoch, then evaluate, log, and possibly checkpoint.

    Steps:
    1. Call do_train on every (image, label) batch.
    2. Anneal the Gumbel-softmax temperature using the global step.
    3. Compare real vs generated P(C, Pn), P(Pn | C), P(C), P(Pn), and the ATE
       of covid on pneumonia. Append the results to ``Exp.tvd_diff`` and save
       them as .npy files.
    4. After epoch 200, if the joint TVD is below 0.20, save model weights and
       the interventional distributions.
    5. Save a sample generated image.

    Args:
        train_dataset: dict with 'img' (image batches) and 'labels' (batches of
            dicts with 'covid_19' and 'pneumonia').

    Raises:
        ValueError: if ``train_dataset`` yields no batches.
    """
    num_batches = len(train_dataset['labels'])
    iteration = 0
    fake_img = None
    two_batches = []
    for img_batch, label_batch in zip(train_dataset['img'], train_dataset['labels']):
        batch1 = tf.reshape(label_batch['covid_19'], [-1, 1])
        batch2 = tf.reshape(label_batch['pneumonia'], [-1, 1])
        udata_batch = tf.concat(axis=1, values=[batch1, batch2])
        image_batch = img_batch  # normalized during data loading

        # The first two label batches of the epoch form the evaluation sample.
        if len(two_batches) < 2:
            two_batches.append(udata_batch)

        G_loss, D_loss, igen_loss, idis_loss, fake_img = do_train(
            Exp, label_generators, discriminators, G_optimizers, D_optimizers, udata_batch, image_batch)

        print('Epoch [%d/%d], Step [%d/%d],' % (
            Exp.curr_epoochs + 1, Exp.num_epochs, iteration + 1, num_batches),
              'mechanism: ', cur_hnodes, ' D_loss: %.4f, G_loss: %.4f' % (D_loss.numpy(), G_loss.numpy()))
        print(f'XrayImg loss, igen_loss:{igen_loss} idis_loss:{idis_loss}')
        print('Reduced temperature:', Exp.Temperature)

        iteration += 1

    # Global step for the annealing schedule. This previously used
    # len(train_dataset), which is the number of dict keys (2), not the number
    # of batches per epoch.
    tot_iter = Exp.curr_epoochs * num_batches + iteration
    Exp.anneal_temperature(tot_iter)

    if not two_batches:
        raise ValueError('train_dataset yielded no batches; nothing to train or evaluate')

    print("--->", Exp.curr_epoochs)

    # ---- evaluation (every epoch) on the first two real label batches ----
    udata_batch = tf.concat(axis=0, values=two_batches)
    test_size = udata_batch.shape[0]
    # Fixed categories keep the column layout [covid=0, covid=1, pneum=0, pneum=1]
    # even if a value is missing from the evaluation sample.
    enc = OneHotEncoder(categories=[[0, 1], [0, 1]])
    udata_batch = enc.fit_transform(udata_batch).toarray()
    pneum_onehot = udata_batch[:, 2:4]
    r_joint_prob, r_cond_list, r_covid_prob, r_pneum_prob = calculate_joint(Exp, udata_batch)

    G_fake = get_generated_labels(Exp, label_generators, {}, test_size, [pneum_onehot])
    fake_batch = tf.concat(axis=1, values=[G_fake['covid_19'], G_fake['pneum']])
    f_joint_prob, f_cond_list, f_covid_prob, f_pneum_prob = calculate_joint(Exp, fake_batch)

    # P(covid, pneum)
    print(f'Real prob:joint_prob:{r_joint_prob}')
    print(f'Fake prob: joint_prob:{f_joint_prob}')
    _log_tvd(Exp, 'joint', f_joint_prob, r_joint_prob)

    # P(pneum | covid)
    print(f'Real prob: P(pneum|covid=0) :{r_cond_list[0]}')
    print(f'Fake prob: P(pneum|covid=0):{f_cond_list[0]}')
    _log_tvd(Exp, 'cond_cov0', f_cond_list[0], r_cond_list[0])

    print(f'Real prob:P(pneum|covid=1):{r_cond_list[1]}')
    print(f'Fake prob: P(pneum|covid=1):{f_cond_list[1]}')
    _log_tvd(Exp, 'cond_cov1', f_cond_list[1], r_cond_list[1])

    # P(covid)
    print(f'Real prob:covid_prob:{r_covid_prob}')
    print(f'Fake prob: covid_prob:{f_covid_prob}')
    _log_tvd(Exp, 'covid', f_covid_prob, r_covid_prob)

    # P(pneum)
    print(f'Real prob:pneum_prob:{r_pneum_prob}')
    print(f'Fake prob: pneum_prob:{f_pneum_prob}')
    _log_tvd(Exp, 'pneum', f_pneum_prob, r_pneum_prob)

    # ATE = P(pneum=1 | do(covid=1)) - P(pneum=1 | do(covid=0))
    intv_pneum_prob_do1 = _pneum_prob_under_do(Exp, label_generators, 1, test_size, pneum_onehot)
    print('P(Y|do(X=1)', intv_pneum_prob_do1)
    intv_pneum_prob_do0 = _pneum_prob_under_do(Exp, label_generators, 0, test_size, pneum_onehot)
    print('P(Y|do(X=0)', intv_pneum_prob_do0)

    ATE = intv_pneum_prob_do1[(1,)] - intv_pneum_prob_do0[(1,)]
    Exp.tvd_diff['ATE'].append(ATE)

    # Show the last (up to) 10 values of every tracked metric.
    ll = -min(10, len(next(iter(Exp.tvd_diff.values()))))
    for dist, values in Exp.tvd_diff.items():
        print("###", dist, " loss%:", [round(val, 4) for val in values[ll:]])

    path = f"/SaveDir/{Exp.exp_name}/tvd"
    os.makedirs(path, exist_ok=True)
    for dist, values in Exp.tvd_diff.items():
        np.save(f'{path}/{dist}.npy', np.array(values))
    print('files saved')

    # ---- checkpoint good models late in training ----
    last_joint_tvd = Exp.tvd_diff['joint'][-1]
    if Exp.curr_epoochs > 200 and last_joint_tvd < 0.20:
        root = "/XrayImageExperiment"
        tag = f'Epoch{Exp.curr_epoochs}_{last_joint_tvd}'
        ckpt_dir = f'{root}/checkpoints/{Exp.exp_name}/{tag}'

        label_generators['covid_19'].save_weights(f'{ckpt_dir}/covid_19_gen/gen')
        label_generators['pneum'][1].save_weights(f'{ckpt_dir}/pneumonia_gen/gen')
        for var in discriminators:
            discriminators[var].save_weights(f'{ckpt_dir}/{var}_disc/disc')
        print('model saved!!!')

        for folder, prob in (('intv1_pneum', intv_pneum_prob_do1), ('intv0_pneum', intv_pneum_prob_do0)):
            os.makedirs(f'{path}/{folder}', exist_ok=True)
            with open(f'{path}/{folder}/{tag}.pkl', 'wb') as f:
                pickle.dump(prob, f)

    # ---- sample image (file names cycle over 5 slots) ----
    os.makedirs('./fake_images', exist_ok=True)
    plt.imshow(fake_img)
    plt.savefig(f'./fake_images/epoch{Exp.curr_epoochs%5}.png')
    # Close the figure; otherwise every epoch's image accumulates on the same axes.
    plt.close()
    # plt.show()





def define_acgan(Exp, latent_dim = 100, adam_lr = 1e-5, adam_beta_1 = 0.5, weights_epoch=288):
    """Build and compile the pre-trained X-ray ACGAN: discriminator, generator, combined model.

    Both networks are restored from
    ``/SaveDir/params_{discriminator,generator}_epoch_{weights_epoch}.hdf5``.
    All three models use Adam with the given learning rate and beta_1.

    Parameters:
    Exp: experiment config; only ``Exp.IMAGE_SIZE`` is read (the discriminator
        input is IMAGE_SIZE x IMAGE_SIZE x 3).
    latent_dim (int, optional): size of the generator's noise input. Used for
        both the generator and the ``latent_noise`` Input. Must be compatible
        with the loaded generator weights.
    adam_lr (float, optional): Adam learning rate.
    adam_beta_1 (float, optional): Adam beta_1.
    weights_epoch (int, optional): epoch tag of the weight files to load.

    Returns:
    combined (Model): generator followed by the frozen discriminator; inputs
        [noise, image_class], outputs [real/fake, class].
    dis (Model): the discriminator, compiled with losses 'source'
        (binary cross-entropy) and 'auxiliary' (sparse categorical cross-entropy).
    gen (Model): the generator.
    """
    # ---- Discriminator ----
    dis = xrayDiscriminator(input_shape=(Exp.IMAGE_SIZE, Exp.IMAGE_SIZE, 3))
    dis.load_weights(f"/SaveDir/params_discriminator_epoch_{weights_epoch}.hdf5")
    # Two heads: real/fake ('source') and class label ('auxiliary').
    dis.compile(
        optimizer=Adam(learning_rate=adam_lr, beta_1=adam_beta_1),
        loss={'source': 'binary_crossentropy', 'auxiliary': 'sparse_categorical_crossentropy'}
    )

    # ---- Generator ----
    # Use the latent_dim argument. It was previously hard-coded to 100, which
    # disagreed with the ``latent_noise`` Input whenever latent_dim != 100.
    gen = xrayGenerator(latent_dim=latent_dim, n_classes=3)
    gen.load_weights(f"/SaveDir/params_generator_epoch_{weights_epoch}.hdf5")
    # Compiled for completeness only; the generator is trained through ``combined``.
    gen.compile(optimizer=Adam(learning_rate=adam_lr, beta_1=adam_beta_1),
                loss='binary_crossentropy')

    # ---- Combined model: (noise, class) -> generator -> frozen discriminator ----
    latent = Input(shape=(latent_dim, ), name='latent_noise')
    image_class = Input(shape=(1,), name='image_class')
    print(image_class.dtype)
    fake_img = gen([latent, image_class])
    print('fake image: ', fake_img.shape)

    # Freeze the discriminator only after compiling it. This relies on Keras
    # capturing ``trainable`` at compile time (Keras 2 behaviour — verify for the
    # installed version): ``dis.train_on_batch`` still updates it, while
    # ``combined`` updates only the generator.
    dis.trainable = False
    fake, aux = dis(fake_img)

    combined = Model(inputs=[latent, image_class],
                     outputs=[fake, aux],
                     name='ACGAN')
    combined.compile(
        optimizer=Adam(learning_rate=adam_lr, beta_1=adam_beta_1),
        loss=['binary_crossentropy', 'sparse_categorical_crossentropy']
    )
    combined.summary()

    return combined, dis, gen




if __name__ == '__main__':
    # Usage: python <this script> [exp_name]
    args = sys.argv
    exp_name = args[1] if len(args) > 1 else 'xrayGantest'

    Exp = Experiment(set_Xray,
                     exp_name=exp_name,
                     NOISE_DIM=64,
                     CONF_NOISE_DIM=64,
                     Temperature=1,
                     temp_min=0.01,
                     ANNEAL_RATE=0.0003,
                     CRITIC_ITERATIONS=1,
                     LAMBDA_GP=10,
                     batch_size=200,
                     ENCODED_DIM=100,
                     Data_intervs=[{}],
                     num_epochs=1000,
                     IMAGE_SIZE=112,
                     new_experiment=True
                     )

    print('Experiment name:', Exp.exp_name)
    # Per-epoch evaluation metrics, appended to by trainloop.
    Exp.tvd_diff = {'joint': [], 'cond_cov0': [], 'cond_cov1': [], 'covid': [], 'pneum': [], 'ATE': []}

    root = "/Dataset/COVIDx-splitted-resized-112"
    data = pd.read_csv(f'{root}/train_dataset.csv')

    # ---- image data ----
    image_data, valid_id = load_dataset(Exp.batch_size, Exp.IMAGE_SIZE, root, data, split='train')

    # ---- label data, restricted to the rows whose images were loaded ----
    label_data = data[["covid_19", "pneumonia"]].iloc[valid_id]

    # Relabel 500 randomly chosen covid_19 == 1 rows as pneumonia = 0, keeping
    # their images, to add covid=1 / pneumonia=0 samples. random.sample raises
    # ValueError if there are fewer than 500 such rows.
    idlist = label_data.index[label_data['covid_19'] == 1].tolist()
    ret = random.sample(idlist, 500)
    label_data.loc[ret, ["pneumonia"]] = 0

    label_dataset = tf.data.Dataset.from_tensor_slices(dict(label_data)).batch(Exp.batch_size)
    train_dataset = {'img': image_data, 'labels': label_dataset}

    # Learning rate: staircase exponential decay from 5e-4 towards 1e-4 over
    # num_epochs, with one decay step per epoch's worth of optimizer steps.
    initial_learning_rate = 5 * 1e-4
    final_learning_rate = 1e-4
    learning_rate_decay_factor = (final_learning_rate / initial_learning_rate) ** (1 / Exp.num_epochs)
    # At least 1: a dataset smaller than one batch would otherwise give
    # decay_steps == 0 (division by zero inside the schedule).
    steps_per_epoch = max(1, len(valid_id) // Exp.batch_size)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_learning_rate,
        decay_steps=steps_per_epoch,
        decay_rate=learning_rate_decay_factor,
        staircase=True)
    # Must be set before the models/optimizers below are created.
    Exp.learning_rate = lr_schedule

    # ---- models ----
    cur_hnodes = {"H2": ["covid_19", "pneum"]}
    label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)
    discriminatorsMech, doptimizersMech = get_discriminators(Exp)

    # Pre-trained X-ray ACGAN; its generator produces the images used in
    # get_generated_labels.
    combined, dis, gen = define_acgan(Exp, latent_dim=100)
    label_generators['xray'] = gen
    label_generators['combined'] = combined
    discriminatorsMech['xray'] = dis

    for epoch in tqdm(range(Exp.num_epochs)):
        Exp.curr_epoochs = epoch
        trainloop(Exp, cur_hnodes, label_generators, optimizersMech, discriminatorsMech, doptimizersMech, train_dataset)



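# Modifications made:
# - Relabelled 500 covid=1 samples as pneumonia=0 (adds covid=1, pneumonia=0 samples; images unchanged)
# - Evaluation uses two batches of real labels as the test sample
# - Images are normalized during data loading
# - Data is shuffled
# - Pneumonia labels are fed as input to the image GAN
# - XrayGAN training: uses two fake batches for GAN training — currently disabled
# - Encoder: pre-trained separately (by me); not trained here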